[BUGFix] Fix UINT/INT8 dequantize implementation and optimize the schedule template for float32 accum #46
This pull request updates several Python files in the `bitblas` library, with the primary goal of improving support for different data types and making the code more robust: `hint.py`, `tensorcore.py`, `lop3.py`, `general_matmul.py`, and `matmul_dequantize_impl.py`. The changes fall into three main categories: updates to `hint.py` and `tensorcore.py` to handle different data types, improvements to `lop3.py` to better handle different bit sizes, and changes to `general_matmul.py` and `matmul_dequantize_impl.py` to add assertions and handle different bit sizes.

**Handling different data types:**
- `python/bitblas/base/roller/hint.py`: Updated the `__repr__` method in the `TensorCoreExtraConfig` class to handle the `float32` and `int32` data types.
- `python/bitblas/base/roller/policy/tensorcore.py`: Modified the `_score` function to set `shared_scope` to `"shared.dyn"` when `out_dtype` is `float32` (see the sketch after this list).
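A minimal sketch of the `tensorcore.py` change, assuming the policy picks the shared memory scope while scoring candidate schedules; the standalone helper below is illustrative, not the actual body of `_score`:

```python
# Illustrative only: in BitBLAS this decision sits inside the tensor-core
# policy's _score path. "shared.dyn" requests dynamically sized shared
# memory, which helps when accumulating in float32, since float32
# accumulators double the shared-memory footprint of float16 ones and can
# exceed the statically allocated shared-memory limit.
def select_shared_scope(out_dtype: str) -> str:
    if out_dtype == "float32":
        return "shared.dyn"
    return "shared"


assert select_shared_scope("float32") == "shared.dyn"
assert select_shared_scope("float16") == "shared"
```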
**Improvements to handle different bit sizes:**
- `python/bitblas/gpu/intrin/lop3.py`: Reformatted the `get_fast_decode_intrin` calls in the `LOP3_FAST_DECODE_INT2_TO_INT8_TO_INT8_L16_INTRIN`, `LOP3_FAST_DECODE_INT1_TO_INT8_TO_INT8_L16_INTRIN`, and `LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_INTRIN` registrations for better readability, and added new `LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_INTRIN` and `LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_SCALE_INTRIN` registrations (a sketch of the registration pattern follows this list).
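The new INT1-to-FP16 registrations presumably follow the same pattern as the existing ones in `lop3.py`: a name constant paired with a `TensorIntrin.register` call on the description/implementation pair returned by `get_fast_decode_intrin`. A hedged sketch; the keyword names (`source_bit`, `storage_dtype`, `target_dtype`, `loops_extent`) and the intrinsic name string are assumptions, not the verbatim code:

```python
from tvm.tir import TensorIntrin

from bitblas.gpu.intrin.lop3 import get_fast_decode_intrin

# Hypothetical reconstruction of one new registration. get_fast_decode_intrin
# is assumed to return a (desc, impl) pair of TIR functions; the keyword
# names below are guesses based on the registration pattern described above.
LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_INTRIN = "lop3_fast_decode_i1_to_int8_to_f16_l8_"

TensorIntrin.register(
    LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_INTRIN,
    *get_fast_decode_intrin(
        source_bit=1,             # 1-bit packed source weights
        storage_dtype="int8",     # weights stored packed in int8 words
        target_dtype="float16",   # decoded to fp16 for the tensor-core path
        loops_extent=8,           # 8 elements decoded per intrinsic call
    ),
)
```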
**Adding assertions and handling different bit sizes:**
- `python/bitblas/ops/general_matmul.py`: Added an `is_not_fast_decoding_supported` helper in the `__initialize_fast_decoding` method and updated the condition in the `transform_weight` method to check that `bit` is less than `8`.
- `python/bitblas/ops/impl/matmul_dequantize_impl.py`: Added assertions that `bit` is in `[1, 2, 4, 8]` to the `matmul_nt_dequantize_b`, `matmul_nt_dequantize_b_propagate_b`, and `matmul_nt_dequantize_b_propagate_a_propagate_b` functions, and updated the `decode_func` in these functions to handle the case where `bit` is `8` (see the sketches after this list).
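The core of the `matmul_dequantize_impl.py` fix: with 8-bit weights, each storage element already holds exactly one value, so `decode_func` must skip the sub-byte unpacking path and cast the element directly. A hedged sketch in TVM's tensor-expression style; names such as `elems_per_byte` and the shift/mask details are illustrative:

```python
import tvm
from tvm import te

bit = 8                       # quantization bit width (1, 2, 4, or 8)
in_dtype = "float16"          # compute dtype
storage_dtype = "int8"        # packed weight storage dtype
elems_per_byte = 8 // bit     # weights packed per storage element
N, K = 16, 16

assert bit in [1, 2, 4, 8], "only 1/2/4/8-bit quantization is supported"

B = te.placeholder((N, K // elems_per_byte), name="B", dtype=storage_dtype)

def decode_func(n, k):
    if bit == 8:
        # One storage element per weight: plain cast, no bit extraction.
        return B[n, k].astype(in_dtype)
    # Sub-byte case: several weights are packed into each storage element
    # and must be extracted by shifting and masking.
    mask = tvm.tir.const((1 << bit) - 1, storage_dtype)
    shift = (k % elems_per_byte * bit).astype(storage_dtype)
    return ((B[n, k // elems_per_byte] >> shift) & mask).astype(in_dtype)

B_decode = te.compute((N, K), decode_func, name="B_decode")
```

And a sketch of the `general_matmul.py` guard, assuming fast (LOP3-based) decoding only pays off for sub-byte integer weights; the exact conditions inside `is_not_fast_decoding_supported` are assumptions:

```python
def is_not_fast_decoding_supported(W_dtype: str, A_dtype: str) -> bool:
    # Illustrative conditions, not the verbatim BitBLAS code.
    conditions = [
        "int" not in W_dtype,          # weights are not integer-quantized
        W_dtype == A_dtype,            # no dequantization is needed at all
        W_dtype in ("int8", "uint8"),  # 8-bit needs no sub-byte unpacking
    ]
    return any(conditions)
```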
**Other changes:**
- `3rdparty/tvm`: Updated the subproject commit.